import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
At the beginning, I had an idea to train the model on a GPU, but at the end of the day it was trained on the CPU (which required a lot of patience :)).
# Pick the training device: prefer the first CUDA GPU when available,
# otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
cpu
The dataset is about various sports. The downloaded data is divided into three folders: train, test and valid. Each of them consists of 100 folders named after the classes (the same for all three folders), in which all the images are stored. The whole set consists of over 13000 pictures. I tried training a net on all of them, but after about 11 hours of training I gave up and interrupted, so the net is trained on about 6000 images, split into 100 classes.
# All three splits share the same preprocessing: convert PIL images to
# [0, 1] tensors, then normalize each RGB channel to roughly [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Root of the downloaded archive, defined once instead of being repeated
# in each ImageFolder call.  Each split folder contains one sub-folder
# per class (100 classes in total).
DATA_ROOT = 'C:/Users/akaga/Python_for_classes/env_for_ml_st/archive'

train_set = ImageFolder(DATA_ROOT + '/train', transform=transform)
test_set = ImageFolder(DATA_ROOT + '/test', transform=transform)
valid_set = ImageFolder(DATA_ROOT + '/valid', transform=transform)
ImageFolder is a "magical" dataset class, which is very useful when loading data organized into per-class folders.
# Bare expression: in the notebook this prints the dataset summary shown below.
train_set
Dataset ImageFolder
Number of datapoints: 13572
Root location: C:/Users/akaga/Python_for_classes/env_for_ml_st/archive/train
StandardTransform
Transform: Compose(
ToTensor()
Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
)
A bit of processing of class names - each class name is the name of the folder in which its photos are stored. As we can see, there are quite a lot of different disciplines, from swimming to hockey.
import os

# Class names are the sub-folder names of the train split.  The list is
# sorted so its indices line up with ImageFolder's class_to_idx mapping,
# which also sorts class names — os.listdir order is not guaranteed, so
# an unsorted list could mislabel predictions.  (Equivalently,
# train_set.classes could be used directly.)
folder = 'C:/Users/akaga/Python_for_classes/env_for_ml_st/archive/train'
classes = sorted(
    name for name in os.listdir(folder) if os.path.isdir(os.path.join(folder, name))
)
print(classes)
['air hockey', 'ampute football', 'archery', 'arm wrestling', 'axe throwing', 'balance beam', 'barell racing', 'baseball', 'basketball', 'baton twirling', 'bike polo', 'billiards', 'bmx', 'bobsled', 'bowling', 'boxing', 'bull riding', 'bungee jumping', 'canoe slamon', 'cheerleading', 'chuckwagon racing', 'cricket', 'croquet', 'curling', 'disc golf', 'fencing', 'field hockey', 'figure skating men', 'figure skating pairs', 'figure skating women', 'fly fishing', 'football', 'formula 1 racing', 'frisbee', 'gaga', 'giant slalom', 'golf', 'hammer throw', 'hang gliding', 'harness racing', 'high jump', 'hockey', 'horse jumping', 'horse racing', 'horseshoe pitching', 'hurdles', 'hydroplane racing', 'ice climbing', 'ice yachting', 'jai alai', 'javelin', 'jousting', 'judo', 'lacrosse', 'log rolling', 'luge', 'motorcycle racing', 'mushing', 'nascar racing', 'olympic wrestling', 'parallel bar', 'pole climbing', 'pole dancing', 'pole vault', 'polo', 'pommel horse', 'rings', 'rock climbing', 'roller derby', 'rollerblade racing', 'rowing', 'rugby', 'sailboat racing', 'shot put', 'shuffleboard', 'sidecar racing', 'ski jumping', 'sky surfing', 'skydiving', 'snow boarding', 'snowmobile racing', 'speed skating', 'steer wrestling', 'sumo wrestling', 'surfing', 'swimming', 'table tennis', 'tennis', 'track bicycle', 'trapeze', 'tug of war', 'ultimate', 'uneven bars', 'volleyball', 'water cycling', 'water polo', 'weightlifting', 'wheelchair basketball', 'wheelchair racing', 'wingsuit flying']
# Number of sport classes (one per folder in each dataset split).
num_classes = 100
from torchvision import models
Below is a function to evaluate the model, used during training to see how it is going and whether the model is overfitting. Additionally, parts from inside that function were used later to check the accuracy on batches from the train, test and validation sets, to show that the model is not overfitted.
def evaluation(dataloader, model):
    """Return the top-1 accuracy (in percent) of *model* over *dataloader*.

    The batches are moved to the device the model's parameters live on,
    so the function no longer depends on a global ``device`` variable.
    Returns 0.0 for an empty dataloader instead of dividing by zero.

    NOTE(review): the model is evaluated in whatever train/eval mode the
    caller left it in; calling ``model.eval()`` here would change the
    numbers the original code produced, so it is deliberately not done.
    """
    device = next(model.parameters()).device
    total, correct = 0, 0
    # no_grad: inference only — skip building the autograd graph, which
    # the original version paid for on every evaluation pass.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, pred = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
    return 100 * correct / total if total else 0.0
The model used here is a pretrained ResNet50, with CrossEntropyLoss and Adam as the optimizer. On the already linked website it was proposed to train for 4 epochs, but that took too long on my computer, so I decided to interrupt before the end of the first epoch; even so, the results seem to be good, especially given how many classes there are.
batch_size = 16

# Per-iteration training history, filled in by the (interrupted) loop below.
iterations = []
accuracies = []
losses = []

trainloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valloader = DataLoader(valid_set, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

# Transfer learning: start from an ImageNet-pretrained ResNet50 and freeze
# the backbone so only the new classification head is trained.
model = torchvision.models.resnet50(pretrained=True)
for param in model.parameters():
    # BUG FIX: the attribute is `requires_grad`; the original assigned a
    # nonexistent `required_grad`, which silently left the whole backbone
    # trainable.
    param.requires_grad = False

# Replace the final fully-connected layer with a fresh 100-way head.
# Created after the freeze above, so its parameters keep requires_grad=True.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=0.0001)

loss_epoch_arr = []
max_epochs = 4
min_loss = 1000000  # sentinel "best loss so far" start value
# 10436 appears to be the number of training images actually used (the
# full train split holds 13572) — presumably the ~epoch subset mentioned
# in the text above; verify against the data actually on disk.
n_iters = np.ceil(10436 / batch_size) * max_epochs
iters = 0
# NOTE(review): this cell is a byte-for-byte re-run of the previous
# setup cell (a duplicated notebook cell); it is kept to preserve the
# notebook's structure, with the same requires_grad fix applied.
batch_size = 16

iterations = []
accuracies = []
losses = []

trainloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valloader = DataLoader(valid_set, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

model = torchvision.models.resnet50(pretrained=True)
for param in model.parameters():
    # BUG FIX: `requires_grad`, not `required_grad` — the original typo
    # meant the backbone was never actually frozen.
    param.requires_grad = False

in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=0.0001)

loss_epoch_arr = []
max_epochs = 4
min_loss = 1000000  # sentinel "best loss so far" start value
n_iters = np.ceil(10436 / batch_size) * max_epochs
iters = 0
# NOTE(review): the training loop below is kept commented out — it was
# interrupted partway through epoch 1 (see the KeyboardInterrupt
# traceback further down).  Observe that `if iters % 1 == 0` evaluates
# the model on the ENTIRE validation set after EVERY training batch,
# which dominates the runtime; raising that modulus (e.g. every 50
# iterations) would speed training up dramatically.
# for epoch in range(max_epochs):
# for i, data in enumerate(trainloader, 0):
# iters += 1
# inputs, labels = data
# inputs, labels = inputs.to(device), labels.to(device)
# opt.zero_grad()
# outputs = model(inputs)
# loss = loss_fn(outputs, labels)
# loss.backward()
# opt.step()
# if iters % 1 == 0:
# curAccuracy = evaluation(valloader, model)
# curLoss = loss.item()
# iterations.append(iters)
# accuracies.append(curAccuracy)
# losses.append(curLoss)
# print('Iteration: %d/%d, Loss: %0.2f, Validation acc: %0.2f'%(iters,n_iters, curLoss, curAccuracy))
# del inputs, labels, outputs
# torch.cuda.empty_cache()
# loss_epoch_arr.append(loss.item())
# print('Epoch: %d/%d ended. Validation acc: %0.2f, Train acc: %0.2f' % (
# epoch+1, max_epochs,
# evaluation(valloader, model),
# evaluation(trainloader, model)))
# print('\n\nTest Accuarcy on final model: %0.4f' % evaluation(testloader, model))
Iteration: 1/2612, Loss: 4.57, Validation acc: 1.60 Iteration: 2/2612, Loss: 4.58, Validation acc: 2.20 Iteration: 3/2612, Loss: 4.50, Validation acc: 3.20 Iteration: 4/2612, Loss: 4.46, Validation acc: 2.80 Iteration: 5/2612, Loss: 4.97, Validation acc: 5.20 Iteration: 6/2612, Loss: 4.70, Validation acc: 4.20 Iteration: 7/2612, Loss: 4.52, Validation acc: 5.60 Iteration: 8/2612, Loss: 4.45, Validation acc: 6.80 Iteration: 9/2612, Loss: 4.33, Validation acc: 6.20 Iteration: 10/2612, Loss: 4.60, Validation acc: 6.60 Iteration: 11/2612, Loss: 4.37, Validation acc: 8.20 Iteration: 12/2612, Loss: 4.60, Validation acc: 9.20 Iteration: 13/2612, Loss: 4.77, Validation acc: 8.80 Iteration: 14/2612, Loss: 4.65, Validation acc: 10.20 Iteration: 15/2612, Loss: 4.76, Validation acc: 9.60 Iteration: 16/2612, Loss: 4.22, Validation acc: 10.20 Iteration: 17/2612, Loss: 4.26, Validation acc: 10.00 Iteration: 18/2612, Loss: 4.36, Validation acc: 11.00 Iteration: 19/2612, Loss: 4.62, Validation acc: 12.20 Iteration: 20/2612, Loss: 4.80, Validation acc: 11.80 Iteration: 21/2612, Loss: 4.53, Validation acc: 12.60 Iteration: 22/2612, Loss: 4.28, Validation acc: 14.60 Iteration: 23/2612, Loss: 4.21, Validation acc: 14.20 Iteration: 24/2612, Loss: 4.27, Validation acc: 15.00 Iteration: 25/2612, Loss: 4.15, Validation acc: 16.60 Iteration: 26/2612, Loss: 4.53, Validation acc: 16.20 Iteration: 27/2612, Loss: 3.86, Validation acc: 15.60 Iteration: 28/2612, Loss: 4.26, Validation acc: 16.00 Iteration: 29/2612, Loss: 4.22, Validation acc: 16.40 Iteration: 30/2612, Loss: 4.18, Validation acc: 16.00 Iteration: 31/2612, Loss: 4.11, Validation acc: 16.60 Iteration: 32/2612, Loss: 4.22, Validation acc: 16.60 Iteration: 33/2612, Loss: 4.25, Validation acc: 16.00 Iteration: 34/2612, Loss: 3.63, Validation acc: 18.40 Iteration: 35/2612, Loss: 3.85, Validation acc: 19.60 Iteration: 36/2612, Loss: 4.11, Validation acc: 18.60 Iteration: 37/2612, Loss: 4.04, Validation acc: 18.20 Iteration: 38/2612, 
Loss: 4.41, Validation acc: 20.20 Iteration: 39/2612, Loss: 3.74, Validation acc: 21.00 Iteration: 40/2612, Loss: 4.31, Validation acc: 19.60 Iteration: 41/2612, Loss: 3.98, Validation acc: 23.00 Iteration: 42/2612, Loss: 3.91, Validation acc: 21.60 Iteration: 43/2612, Loss: 3.93, Validation acc: 23.80 Iteration: 44/2612, Loss: 3.79, Validation acc: 24.60 Iteration: 45/2612, Loss: 4.08, Validation acc: 23.80 Iteration: 46/2612, Loss: 4.09, Validation acc: 24.40 Iteration: 47/2612, Loss: 3.79, Validation acc: 26.00 Iteration: 48/2612, Loss: 3.61, Validation acc: 24.40 Iteration: 49/2612, Loss: 3.99, Validation acc: 25.60 Iteration: 50/2612, Loss: 3.97, Validation acc: 24.40 Iteration: 51/2612, Loss: 3.93, Validation acc: 25.40 Iteration: 52/2612, Loss: 3.60, Validation acc: 24.80 Iteration: 53/2612, Loss: 3.83, Validation acc: 25.60 Iteration: 54/2612, Loss: 3.95, Validation acc: 27.20 Iteration: 55/2612, Loss: 3.92, Validation acc: 26.00 Iteration: 56/2612, Loss: 3.32, Validation acc: 28.40 Iteration: 57/2612, Loss: 3.58, Validation acc: 26.60 Iteration: 58/2612, Loss: 3.43, Validation acc: 27.80 Iteration: 59/2612, Loss: 3.82, Validation acc: 29.40 Iteration: 60/2612, Loss: 3.31, Validation acc: 29.60 Iteration: 61/2612, Loss: 3.50, Validation acc: 30.20 Iteration: 62/2612, Loss: 3.70, Validation acc: 27.60 Iteration: 63/2612, Loss: 3.80, Validation acc: 31.20 Iteration: 64/2612, Loss: 3.68, Validation acc: 33.20 Iteration: 65/2612, Loss: 3.50, Validation acc: 32.80 Iteration: 66/2612, Loss: 3.66, Validation acc: 34.20 Iteration: 67/2612, Loss: 3.63, Validation acc: 33.20 Iteration: 68/2612, Loss: 3.12, Validation acc: 34.60 Iteration: 69/2612, Loss: 3.56, Validation acc: 36.20 Iteration: 70/2612, Loss: 3.39, Validation acc: 36.00 Iteration: 71/2612, Loss: 3.56, Validation acc: 36.00 Iteration: 72/2612, Loss: 3.59, Validation acc: 36.80 Iteration: 73/2612, Loss: 3.34, Validation acc: 36.80 Iteration: 74/2612, Loss: 3.20, Validation acc: 38.20 Iteration: 75/2612, 
Loss: 2.99, Validation acc: 41.80 Iteration: 76/2612, Loss: 3.50, Validation acc: 38.20 Iteration: 77/2612, Loss: 3.11, Validation acc: 40.40 Iteration: 78/2612, Loss: 3.16, Validation acc: 41.60 Iteration: 79/2612, Loss: 3.02, Validation acc: 40.80 Iteration: 80/2612, Loss: 3.45, Validation acc: 42.60 Iteration: 81/2612, Loss: 3.24, Validation acc: 41.00 Iteration: 82/2612, Loss: 3.45, Validation acc: 44.00 Iteration: 83/2612, Loss: 3.03, Validation acc: 40.40 Iteration: 84/2612, Loss: 3.11, Validation acc: 42.60 Iteration: 85/2612, Loss: 3.59, Validation acc: 42.80 Iteration: 86/2612, Loss: 2.57, Validation acc: 43.60 Iteration: 87/2612, Loss: 3.09, Validation acc: 41.40 Iteration: 88/2612, Loss: 2.61, Validation acc: 42.60 Iteration: 89/2612, Loss: 3.07, Validation acc: 43.80 Iteration: 90/2612, Loss: 2.75, Validation acc: 44.40 Iteration: 91/2612, Loss: 2.61, Validation acc: 44.60 Iteration: 92/2612, Loss: 2.34, Validation acc: 44.20 Iteration: 93/2612, Loss: 2.86, Validation acc: 43.40 Iteration: 94/2612, Loss: 2.95, Validation acc: 45.00 Iteration: 95/2612, Loss: 2.91, Validation acc: 44.40 Iteration: 96/2612, Loss: 2.87, Validation acc: 47.20 Iteration: 97/2612, Loss: 2.98, Validation acc: 45.40 Iteration: 98/2612, Loss: 2.93, Validation acc: 48.40 Iteration: 99/2612, Loss: 2.57, Validation acc: 48.40 Iteration: 100/2612, Loss: 3.19, Validation acc: 45.80 Iteration: 101/2612, Loss: 2.67, Validation acc: 48.60 Iteration: 102/2612, Loss: 2.77, Validation acc: 48.20 Iteration: 103/2612, Loss: 3.02, Validation acc: 46.00 Iteration: 104/2612, Loss: 3.42, Validation acc: 46.80 Iteration: 105/2612, Loss: 2.84, Validation acc: 46.80 Iteration: 106/2612, Loss: 2.83, Validation acc: 45.40 Iteration: 107/2612, Loss: 2.37, Validation acc: 47.60 Iteration: 108/2612, Loss: 2.63, Validation acc: 47.00 Iteration: 109/2612, Loss: 2.82, Validation acc: 46.40 Iteration: 110/2612, Loss: 2.62, Validation acc: 46.20 Iteration: 111/2612, Loss: 2.42, Validation acc: 47.80 
Iteration: 112/2612, Loss: 2.45, Validation acc: 49.20 Iteration: 113/2612, Loss: 2.25, Validation acc: 49.40 Iteration: 114/2612, Loss: 2.19, Validation acc: 50.20 Iteration: 115/2612, Loss: 2.69, Validation acc: 51.60 Iteration: 116/2612, Loss: 2.59, Validation acc: 48.20 Iteration: 117/2612, Loss: 2.71, Validation acc: 51.40 Iteration: 118/2612, Loss: 2.31, Validation acc: 50.80 Iteration: 119/2612, Loss: 2.91, Validation acc: 48.80 Iteration: 120/2612, Loss: 2.60, Validation acc: 51.80 Iteration: 121/2612, Loss: 2.66, Validation acc: 50.00 Iteration: 122/2612, Loss: 2.07, Validation acc: 49.80 Iteration: 123/2612, Loss: 2.66, Validation acc: 50.00 Iteration: 124/2612, Loss: 3.03, Validation acc: 51.80 Iteration: 125/2612, Loss: 2.36, Validation acc: 52.40 Iteration: 126/2612, Loss: 2.45, Validation acc: 52.40 Iteration: 127/2612, Loss: 2.05, Validation acc: 53.00 Iteration: 128/2612, Loss: 2.33, Validation acc: 52.60 Iteration: 129/2612, Loss: 2.30, Validation acc: 54.20 Iteration: 130/2612, Loss: 2.76, Validation acc: 53.20 Iteration: 131/2612, Loss: 2.10, Validation acc: 56.00 Iteration: 132/2612, Loss: 2.31, Validation acc: 53.40 Iteration: 133/2612, Loss: 2.05, Validation acc: 55.20 Iteration: 134/2612, Loss: 2.31, Validation acc: 54.20 Iteration: 135/2612, Loss: 2.65, Validation acc: 53.00 Iteration: 136/2612, Loss: 2.62, Validation acc: 53.20 Iteration: 137/2612, Loss: 2.45, Validation acc: 53.00 Iteration: 138/2612, Loss: 2.01, Validation acc: 53.80 Iteration: 139/2612, Loss: 1.71, Validation acc: 54.20 Iteration: 140/2612, Loss: 2.39, Validation acc: 54.80 Iteration: 141/2612, Loss: 2.26, Validation acc: 55.20 Iteration: 142/2612, Loss: 2.25, Validation acc: 56.20 Iteration: 143/2612, Loss: 2.29, Validation acc: 58.20 Iteration: 144/2612, Loss: 2.28, Validation acc: 57.40 Iteration: 145/2612, Loss: 1.62, Validation acc: 57.80 Iteration: 146/2612, Loss: 1.80, Validation acc: 59.40 Iteration: 147/2612, Loss: 2.34, Validation acc: 59.00 Iteration: 
148/2612, Loss: 2.10, Validation acc: 57.20 Iteration: 149/2612, Loss: 2.50, Validation acc: 59.40 Iteration: 150/2612, Loss: 1.99, Validation acc: 57.60 Iteration: 151/2612, Loss: 2.21, Validation acc: 59.60 Iteration: 152/2612, Loss: 2.38, Validation acc: 57.40 Iteration: 153/2612, Loss: 2.69, Validation acc: 59.00 Iteration: 154/2612, Loss: 1.99, Validation acc: 56.40 Iteration: 155/2612, Loss: 2.36, Validation acc: 55.40 Iteration: 156/2612, Loss: 2.18, Validation acc: 59.40 Iteration: 157/2612, Loss: 2.09, Validation acc: 59.00 Iteration: 158/2612, Loss: 2.61, Validation acc: 60.40 Iteration: 159/2612, Loss: 2.31, Validation acc: 60.00 Iteration: 160/2612, Loss: 2.28, Validation acc: 59.20 Iteration: 161/2612, Loss: 2.29, Validation acc: 61.20 Iteration: 162/2612, Loss: 1.70, Validation acc: 61.20 Iteration: 163/2612, Loss: 2.02, Validation acc: 63.60 Iteration: 164/2612, Loss: 2.08, Validation acc: 58.40 Iteration: 165/2612, Loss: 1.77, Validation acc: 60.20 Iteration: 166/2612, Loss: 1.69, Validation acc: 62.20 Iteration: 167/2612, Loss: 2.45, Validation acc: 61.80 Iteration: 168/2612, Loss: 1.74, Validation acc: 61.40 Iteration: 169/2612, Loss: 2.25, Validation acc: 60.60 Iteration: 170/2612, Loss: 1.67, Validation acc: 60.40 Iteration: 171/2612, Loss: 1.88, Validation acc: 62.40 Iteration: 172/2612, Loss: 1.47, Validation acc: 61.00 Iteration: 173/2612, Loss: 2.88, Validation acc: 61.40 Iteration: 174/2612, Loss: 1.64, Validation acc: 62.80 Iteration: 175/2612, Loss: 2.31, Validation acc: 65.20 Iteration: 176/2612, Loss: 2.26, Validation acc: 63.80 Iteration: 177/2612, Loss: 1.24, Validation acc: 62.00 Iteration: 178/2612, Loss: 2.01, Validation acc: 64.20 Iteration: 179/2612, Loss: 1.84, Validation acc: 63.40 Iteration: 180/2612, Loss: 1.60, Validation acc: 63.60 Iteration: 181/2612, Loss: 1.61, Validation acc: 63.20 Iteration: 182/2612, Loss: 1.99, Validation acc: 65.60 Iteration: 183/2612, Loss: 2.24, Validation acc: 64.40 Iteration: 184/2612, Loss: 
2.06, Validation acc: 64.20 Iteration: 185/2612, Loss: 1.63, Validation acc: 63.20 Iteration: 186/2612, Loss: 1.55, Validation acc: 65.60 Iteration: 187/2612, Loss: 1.75, Validation acc: 64.60 Iteration: 188/2612, Loss: 1.46, Validation acc: 66.00 Iteration: 189/2612, Loss: 1.53, Validation acc: 66.40 Iteration: 190/2612, Loss: 1.32, Validation acc: 66.40 Iteration: 191/2612, Loss: 1.61, Validation acc: 65.60 Iteration: 192/2612, Loss: 1.57, Validation acc: 67.00 Iteration: 193/2612, Loss: 2.05, Validation acc: 64.60 Iteration: 194/2612, Loss: 1.94, Validation acc: 65.20 Iteration: 195/2612, Loss: 1.70, Validation acc: 66.60 Iteration: 196/2612, Loss: 1.65, Validation acc: 66.20 Iteration: 197/2612, Loss: 1.44, Validation acc: 67.60 Iteration: 198/2612, Loss: 1.36, Validation acc: 67.20 Iteration: 199/2612, Loss: 1.86, Validation acc: 71.00 Iteration: 200/2612, Loss: 1.84, Validation acc: 66.40 Iteration: 201/2612, Loss: 1.71, Validation acc: 69.80 Iteration: 202/2612, Loss: 1.59, Validation acc: 69.40 Iteration: 203/2612, Loss: 1.74, Validation acc: 69.20 Iteration: 204/2612, Loss: 1.76, Validation acc: 69.00 Iteration: 205/2612, Loss: 1.60, Validation acc: 67.40 Iteration: 206/2612, Loss: 1.48, Validation acc: 67.80 Iteration: 207/2612, Loss: 1.61, Validation acc: 70.60 Iteration: 208/2612, Loss: 2.05, Validation acc: 69.80 Iteration: 209/2612, Loss: 1.63, Validation acc: 69.60 Iteration: 210/2612, Loss: 1.88, Validation acc: 69.00 Iteration: 211/2612, Loss: 1.70, Validation acc: 70.00 Iteration: 212/2612, Loss: 1.71, Validation acc: 70.60 Iteration: 213/2612, Loss: 1.18, Validation acc: 71.20 Iteration: 214/2612, Loss: 1.62, Validation acc: 68.00 Iteration: 215/2612, Loss: 1.58, Validation acc: 71.00 Iteration: 216/2612, Loss: 1.26, Validation acc: 70.00 Iteration: 217/2612, Loss: 2.30, Validation acc: 68.60 Iteration: 218/2612, Loss: 2.05, Validation acc: 70.00 Iteration: 219/2612, Loss: 1.43, Validation acc: 70.00 Iteration: 220/2612, Loss: 1.70, Validation 
acc: 68.80 Iteration: 221/2612, Loss: 1.13, Validation acc: 68.60 Iteration: 222/2612, Loss: 1.56, Validation acc: 70.00 Iteration: 223/2612, Loss: 1.66, Validation acc: 68.80 Iteration: 224/2612, Loss: 1.50, Validation acc: 70.40 Iteration: 225/2612, Loss: 1.35, Validation acc: 68.80 Iteration: 226/2612, Loss: 1.43, Validation acc: 66.00 Iteration: 227/2612, Loss: 1.76, Validation acc: 68.40 Iteration: 228/2612, Loss: 1.51, Validation acc: 70.00 Iteration: 229/2612, Loss: 1.78, Validation acc: 69.60 Iteration: 230/2612, Loss: 1.50, Validation acc: 72.40 Iteration: 231/2612, Loss: 1.83, Validation acc: 68.80 Iteration: 232/2612, Loss: 1.69, Validation acc: 71.60 Iteration: 233/2612, Loss: 1.64, Validation acc: 72.20 Iteration: 234/2612, Loss: 1.45, Validation acc: 67.80 Iteration: 235/2612, Loss: 1.64, Validation acc: 70.60 Iteration: 236/2612, Loss: 1.52, Validation acc: 70.80 Iteration: 237/2612, Loss: 1.11, Validation acc: 75.20 Iteration: 238/2612, Loss: 2.04, Validation acc: 72.20 Iteration: 239/2612, Loss: 1.62, Validation acc: 72.60 Iteration: 240/2612, Loss: 2.07, Validation acc: 71.60 Iteration: 241/2612, Loss: 1.38, Validation acc: 73.80 Iteration: 242/2612, Loss: 1.77, Validation acc: 74.40 Iteration: 243/2612, Loss: 1.16, Validation acc: 73.40 Iteration: 244/2612, Loss: 1.44, Validation acc: 70.60 Iteration: 245/2612, Loss: 1.25, Validation acc: 71.20 Iteration: 246/2612, Loss: 1.43, Validation acc: 71.40 Iteration: 247/2612, Loss: 1.94, Validation acc: 73.60 Iteration: 248/2612, Loss: 1.22, Validation acc: 72.80 Iteration: 249/2612, Loss: 1.79, Validation acc: 72.40 Iteration: 250/2612, Loss: 1.80, Validation acc: 73.00 Iteration: 251/2612, Loss: 1.38, Validation acc: 74.20 Iteration: 252/2612, Loss: 1.16, Validation acc: 75.40 Iteration: 253/2612, Loss: 1.28, Validation acc: 73.80 Iteration: 254/2612, Loss: 1.49, Validation acc: 71.40 Iteration: 255/2612, Loss: 1.04, Validation acc: 73.40 Iteration: 256/2612, Loss: 1.05, Validation acc: 74.60 
Iteration: 257/2612, Loss: 1.98, Validation acc: 73.20 Iteration: 258/2612, Loss: 1.18, Validation acc: 72.00 Iteration: 259/2612, Loss: 1.62, Validation acc: 73.40 Iteration: 260/2612, Loss: 1.64, Validation acc: 72.60 Iteration: 261/2612, Loss: 1.45, Validation acc: 74.60 Iteration: 262/2612, Loss: 1.68, Validation acc: 76.20 Iteration: 263/2612, Loss: 1.54, Validation acc: 75.20 Iteration: 264/2612, Loss: 1.57, Validation acc: 73.80 Iteration: 265/2612, Loss: 1.71, Validation acc: 74.20 Iteration: 266/2612, Loss: 1.14, Validation acc: 75.80 Iteration: 267/2612, Loss: 1.62, Validation acc: 76.00 Iteration: 268/2612, Loss: 1.40, Validation acc: 75.20 Iteration: 269/2612, Loss: 1.48, Validation acc: 74.60 Iteration: 270/2612, Loss: 1.29, Validation acc: 76.00 Iteration: 271/2612, Loss: 1.34, Validation acc: 77.80 Iteration: 272/2612, Loss: 1.23, Validation acc: 74.80 Iteration: 273/2612, Loss: 1.00, Validation acc: 75.20 Iteration: 274/2612, Loss: 0.90, Validation acc: 78.00 Iteration: 275/2612, Loss: 1.90, Validation acc: 78.20 Iteration: 276/2612, Loss: 1.21, Validation acc: 76.80 Iteration: 277/2612, Loss: 1.49, Validation acc: 76.60 Iteration: 278/2612, Loss: 1.28, Validation acc: 76.40 Iteration: 279/2612, Loss: 1.39, Validation acc: 76.40 Iteration: 280/2612, Loss: 1.40, Validation acc: 76.80 Iteration: 281/2612, Loss: 1.22, Validation acc: 78.00 Iteration: 282/2612, Loss: 1.23, Validation acc: 74.40 Iteration: 283/2612, Loss: 1.26, Validation acc: 77.60 Iteration: 284/2612, Loss: 1.32, Validation acc: 77.60 Iteration: 285/2612, Loss: 1.26, Validation acc: 77.40 Iteration: 286/2612, Loss: 1.72, Validation acc: 77.20 Iteration: 287/2612, Loss: 1.37, Validation acc: 78.40 Iteration: 288/2612, Loss: 0.99, Validation acc: 76.00 Iteration: 289/2612, Loss: 1.42, Validation acc: 76.40 Iteration: 290/2612, Loss: 1.56, Validation acc: 79.20 Iteration: 291/2612, Loss: 1.24, Validation acc: 75.80 Iteration: 292/2612, Loss: 1.04, Validation acc: 74.80 Iteration: 
293/2612, Loss: 1.17, Validation acc: 76.20 Iteration: 294/2612, Loss: 0.91, Validation acc: 76.80 Iteration: 295/2612, Loss: 1.82, Validation acc: 76.40 Iteration: 296/2612, Loss: 1.04, Validation acc: 77.00 Iteration: 297/2612, Loss: 1.60, Validation acc: 74.00 Iteration: 298/2612, Loss: 1.53, Validation acc: 75.20 Iteration: 299/2612, Loss: 1.28, Validation acc: 76.20 Iteration: 300/2612, Loss: 1.14, Validation acc: 75.80 Iteration: 301/2612, Loss: 1.20, Validation acc: 74.20 Iteration: 302/2612, Loss: 1.34, Validation acc: 75.60 Iteration: 303/2612, Loss: 1.32, Validation acc: 77.00 Iteration: 304/2612, Loss: 1.52, Validation acc: 77.40 Iteration: 305/2612, Loss: 1.33, Validation acc: 76.20 Iteration: 306/2612, Loss: 1.27, Validation acc: 77.00 Iteration: 307/2612, Loss: 1.25, Validation acc: 76.80 Iteration: 308/2612, Loss: 0.96, Validation acc: 78.20 Iteration: 309/2612, Loss: 1.29, Validation acc: 77.00 Iteration: 310/2612, Loss: 1.58, Validation acc: 76.80 Iteration: 311/2612, Loss: 0.96, Validation acc: 79.20 Iteration: 312/2612, Loss: 0.74, Validation acc: 77.80 Iteration: 313/2612, Loss: 1.10, Validation acc: 75.20 Iteration: 314/2612, Loss: 1.57, Validation acc: 78.20 Iteration: 315/2612, Loss: 1.21, Validation acc: 76.20 Iteration: 316/2612, Loss: 1.44, Validation acc: 79.00 Iteration: 317/2612, Loss: 1.26, Validation acc: 80.60 Iteration: 318/2612, Loss: 1.50, Validation acc: 77.60 Iteration: 319/2612, Loss: 0.92, Validation acc: 76.80 Iteration: 320/2612, Loss: 1.98, Validation acc: 77.80 Iteration: 321/2612, Loss: 0.92, Validation acc: 78.40 Iteration: 322/2612, Loss: 1.18, Validation acc: 77.80 Iteration: 323/2612, Loss: 1.54, Validation acc: 74.00 Iteration: 324/2612, Loss: 1.26, Validation acc: 77.20 Iteration: 325/2612, Loss: 1.13, Validation acc: 78.60 Iteration: 326/2612, Loss: 1.20, Validation acc: 77.20 Iteration: 327/2612, Loss: 1.15, Validation acc: 76.00 Iteration: 328/2612, Loss: 1.14, Validation acc: 79.00 Iteration: 329/2612, Loss: 
1.60, Validation acc: 76.60 Iteration: 330/2612, Loss: 1.18, Validation acc: 78.00 Iteration: 331/2612, Loss: 1.14, Validation acc: 80.00 Iteration: 332/2612, Loss: 1.21, Validation acc: 79.40 Iteration: 333/2612, Loss: 1.15, Validation acc: 81.20 Iteration: 334/2612, Loss: 0.99, Validation acc: 81.60 Iteration: 335/2612, Loss: 1.55, Validation acc: 80.80 Iteration: 336/2612, Loss: 1.14, Validation acc: 79.20 Iteration: 337/2612, Loss: 1.85, Validation acc: 81.40 Iteration: 338/2612, Loss: 1.56, Validation acc: 81.60 Iteration: 339/2612, Loss: 1.51, Validation acc: 80.40 Iteration: 340/2612, Loss: 0.95, Validation acc: 79.40 Iteration: 341/2612, Loss: 0.92, Validation acc: 79.80 Iteration: 342/2612, Loss: 0.79, Validation acc: 79.60 Iteration: 343/2612, Loss: 1.22, Validation acc: 79.60 Iteration: 344/2612, Loss: 0.89, Validation acc: 81.00 Iteration: 345/2612, Loss: 0.74, Validation acc: 82.60 Iteration: 346/2612, Loss: 1.06, Validation acc: 78.60 Iteration: 347/2612, Loss: 1.16, Validation acc: 80.20 Iteration: 348/2612, Loss: 0.93, Validation acc: 79.00 Iteration: 349/2612, Loss: 1.39, Validation acc: 81.80 Iteration: 350/2612, Loss: 1.10, Validation acc: 78.40 Iteration: 351/2612, Loss: 0.95, Validation acc: 80.80 Iteration: 352/2612, Loss: 0.91, Validation acc: 80.40 Iteration: 353/2612, Loss: 1.65, Validation acc: 80.40 Iteration: 354/2612, Loss: 1.02, Validation acc: 77.60 Iteration: 355/2612, Loss: 1.07, Validation acc: 78.80 Iteration: 356/2612, Loss: 1.15, Validation acc: 81.60 Iteration: 357/2612, Loss: 1.45, Validation acc: 80.40 Iteration: 358/2612, Loss: 0.76, Validation acc: 77.80 Iteration: 359/2612, Loss: 1.54, Validation acc: 79.80 Iteration: 360/2612, Loss: 0.62, Validation acc: 79.00 Iteration: 361/2612, Loss: 1.50, Validation acc: 78.00 Iteration: 362/2612, Loss: 0.85, Validation acc: 77.60 Iteration: 363/2612, Loss: 0.79, Validation acc: 77.20 Iteration: 364/2612, Loss: 1.39, Validation acc: 78.40 Iteration: 365/2612, Loss: 1.31, Validation 
acc: 79.20 Iteration: 366/2612, Loss: 0.76, Validation acc: 79.00 Iteration: 367/2612, Loss: 1.44, Validation acc: 80.40 Iteration: 368/2612, Loss: 1.16, Validation acc: 79.20 Iteration: 369/2612, Loss: 1.00, Validation acc: 81.20 Iteration: 370/2612, Loss: 1.08, Validation acc: 79.20 Iteration: 371/2612, Loss: 0.77, Validation acc: 79.20 Iteration: 372/2612, Loss: 1.29, Validation acc: 78.40 Iteration: 373/2612, Loss: 1.23, Validation acc: 79.40 Iteration: 374/2612, Loss: 1.01, Validation acc: 79.20 Iteration: 375/2612, Loss: 1.01, Validation acc: 80.60 Iteration: 376/2612, Loss: 1.25, Validation acc: 79.80 Iteration: 377/2612, Loss: 1.37, Validation acc: 79.40 Iteration: 378/2612, Loss: 0.88, Validation acc: 81.20 Iteration: 379/2612, Loss: 1.07, Validation acc: 77.80 Iteration: 380/2612, Loss: 0.94, Validation acc: 79.40 Iteration: 381/2612, Loss: 0.79, Validation acc: 81.60 Iteration: 382/2612, Loss: 1.14, Validation acc: 81.00 Iteration: 383/2612, Loss: 0.93, Validation acc: 79.40 Iteration: 384/2612, Loss: 0.84, Validation acc: 80.80 Iteration: 385/2612, Loss: 0.78, Validation acc: 78.80 Iteration: 386/2612, Loss: 0.99, Validation acc: 81.60 Iteration: 387/2612, Loss: 0.93, Validation acc: 80.20 Iteration: 388/2612, Loss: 1.21, Validation acc: 78.60 Iteration: 389/2612, Loss: 0.65, Validation acc: 80.80 Iteration: 390/2612, Loss: 1.15, Validation acc: 82.20 Iteration: 391/2612, Loss: 0.83, Validation acc: 81.40 Iteration: 392/2612, Loss: 0.78, Validation acc: 79.40 Iteration: 393/2612, Loss: 1.56, Validation acc: 81.20 Iteration: 394/2612, Loss: 0.86, Validation acc: 80.40 Iteration: 395/2612, Loss: 0.77, Validation acc: 82.20 Iteration: 396/2612, Loss: 0.73, Validation acc: 80.00
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) ~\AppData\Local\Temp/ipykernel_4896/3612283344.py in <module> 44 45 if iters % 1 == 0: ---> 46 curAccuracy = evaluation(valloader, model) 47 curLoss = loss.item() 48 iterations.append(iters) ~\AppData\Local\Temp/ipykernel_4896/3793694988.py in evaluation(dataloader, model) 4 inputs, labels = data 5 inputs, labels = inputs.to(device), labels.to(device) ----> 6 outputs = model(inputs) 7 _, pred = torch.max(outputs.data, 1) 8 total += labels.size(0) c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs) 1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1101 or _global_forward_hooks or _global_forward_pre_hooks): -> 1102 return forward_call(*input, **kwargs) 1103 # Do not call functions when jit is used 1104 full_backward_hooks, non_full_backward_hooks = [], [] c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torchvision\models\resnet.py in forward(self, x) 247 248 def forward(self, x: Tensor) -> Tensor: --> 249 return self._forward_impl(x) 250 251 c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torchvision\models\resnet.py in _forward_impl(self, x) 236 237 x = self.layer1(x) --> 238 x = self.layer2(x) 239 x = self.layer3(x) 240 x = self.layer4(x) c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs) 1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1101 or _global_forward_hooks or _global_forward_pre_hooks): -> 1102 return forward_call(*input, **kwargs) 1103 # Do not call functions when jit is used 1104 full_backward_hooks, non_full_backward_hooks = [], [] 
c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torch\nn\modules\container.py in forward(self, input) 139 def forward(self, input): 140 for module in self: --> 141 input = module(input) 142 return input 143 c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs) 1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1101 or _global_forward_hooks or _global_forward_pre_hooks): -> 1102 return forward_call(*input, **kwargs) 1103 # Do not call functions when jit is used 1104 full_backward_hooks, non_full_backward_hooks = [], [] c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torchvision\models\resnet.py in forward(self, x) 122 identity = x 123 --> 124 out = self.conv1(x) 125 out = self.bn1(out) 126 out = self.relu(out) c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs) 1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1101 or _global_forward_hooks or _global_forward_pre_hooks): -> 1102 return forward_call(*input, **kwargs) 1103 # Do not call functions when jit is used 1104 full_backward_hooks, non_full_backward_hooks = [], [] c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torch\nn\modules\conv.py in forward(self, input) 444 445 def forward(self, input: Tensor) -> Tensor: --> 446 return self._conv_forward(input, self.weight, self.bias) 447 448 class Conv3d(_ConvNd): c:\Users\akaga\Python_for_classes\env_for_ml_st\lib\site-packages\torch\nn\modules\conv.py in _conv_forward(self, input, weight, bias) 440 weight, bias, self.stride, 441 _pair(0), self.dilation, self.groups) --> 442 return F.conv2d(input, weight, bias, self.stride, 443 self.padding, self.dilation, self.groups) 444 KeyboardInterrupt:
In case anything goes wrong while doing this homework (and to be able to reopen it quickly later), I wrote to files both the trained model and the information from the first batch of the testloader - images (inputs), labels and predicted labels. It also kept the pictures I explain the same every time, so I was able to draw some conclusions.
# Persist the trained network so the multi-hour training run does not
# have to be repeated.  NOTE: torch.save(model, ...) pickles the whole
# module, so loading requires the same torchvision class definitions to
# be importable; saving model.state_dict() would be the more portable
# choice.
filename = 'finalized_model10.pth'
torch.save(model, filename)

filename = 'finalized_model10.pth'
# map_location makes the load work regardless of the device the model
# was saved from (e.g. a GPU checkpoint opened on a CPU-only machine).
loaded_model = torch.load(filename, map_location=device)
type(loaded_model)
torchvision.models.resnet.ResNet
Checking scores on the validation, train and test sets. The model does not seem to be overfitted.
# score on validation set
# total=0
# correct=0
# inputs, labels = next(iter(valloader))
# inputs, labels = inputs.to(device), labels.to(device)
# outputs = loaded_model(inputs)
# _, pred = torch.max(outputs.data, 1)
# total += labels.size(0)
# correct += (pred == labels).sum().item()
# print(100 * correct / total)
81.25
#score on train set
# total=0
# correct=0
# inputs, labels = next(iter(trainloader))
# inputs, labels = inputs.to(device), labels.to(device)
# outputs = loaded_model(inputs)
# _, pred = torch.max(outputs.data, 1)
# total += labels.size(0)
# correct += (pred == labels).sum().item()
# print(100 * correct / total)
81.25
#score on test set
# total=0
# correct=0
# inputs, labels = next(iter(testloader))
# inputs, labels = inputs.to(device), labels.to(device)
# outputs = loaded_model(inputs)
# _, pred = torch.max(outputs.data, 1)
# total += labels.size(0)
# correct += (pred == labels).sum().item()
# print(100 * correct / total)
75.0
# #save inputs, labels and prediction
# torch.save(inputs, 'input_to_analyze.pt')
# torch.save(labels, 'labels_to_analyze.pt')
# torch.save(pred, 'pred_to_analyze.pt')
#reading from files
# Restore the earlier-saved test batch (images, true labels, predictions) so the
# explanation figures below stay identical across notebook runs.
inputs=torch.load('input_to_analyze.pt')
labels=torch.load('labels_to_analyze.pt')
pred=torch.load('pred_to_analyze.pt')
Additionaly, the size of photos is 224 x 224, and the total number of images in train set is over 13000 (however, as stated before, not all were used).
inputs.size()
torch.Size([16, 3, 224, 224])
Pictures and both true and predicted labels for earlier saved batch from test set. Images seems to be in good resolution, in most cases, we can easily say which sport is on the picture. However, due to the normalization at the beginning, some of the photos seems to be a bit too dark. In this batch, we have 12 correct predicted and 4 false labels.
# Show the whole saved test batch as an image grid and print, for each picture,
# its true class name next to the model's predicted class name.
print(inputs.shape)
print(labels)
print(pred)
grid_testing = torchvision.utils.make_grid(inputs, nrow=8)
plt.figure(figsize=(15,15))
plt.imshow(grid_testing.permute(1,2,0))
# Iterate over the actual batch length instead of the external `batch_size`
# variable: same output for a full batch, but also correct for a shorter
# (final, partially-filled) batch and independent of cell execution order.
for i in range(len(labels)):
    print(f"label: {classes[labels[i]]}, pred: {classes[pred[i]]}")
torch.Size([16, 3, 224, 224]) tensor([66, 69, 98, 63, 50, 41, 61, 18, 79, 98, 60, 20, 91, 92, 7, 36]) tensor([66, 69, 98, 62, 50, 41, 61, 84, 47, 88, 60, 20, 91, 92, 7, 36])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
label: rings, pred: rings label: rollerblade racing, pred: rollerblade racing label: wheelchair racing, pred: wheelchair racing label: pole vault, pred: pole dancing label: javelin, pred: javelin label: hockey, pred: hockey label: pole climbing, pred: pole climbing label: canoe slamon, pred: surfing label: snow boarding, pred: ice climbing label: wheelchair racing, pred: track bicycle label: parallel bar, pred: parallel bar label: chuckwagon racing, pred: chuckwagon racing label: ultimate, pred: ultimate label: uneven bars, pred: uneven bars label: baseball, pred: baseball label: golf, pred: golf
In some of the cases, I think I would also have problems with recognizing, what is happening on the pictures, for example on a picture in bottom row, between horses and gimnastics girl. Additionaly, for the second one (rollerblade racing), I was surprised, that the model did it good, as the rolls are hardly visible. What is interesting, model made a mistake with the picture with snowboard - for me it is clearly visible, so later I will take a look at the explanation of a model - maybe it will help me understand, why it decided to predict 'ice climbing' instead of 'snowboarding'.
print(f"batch number: {len(trainloader)}, batch size: {batch_size}, all together, number of images in train set: {len(trainloader)*batch_size}") #trainloader contains len(trainloader) batches, each batch has 16 images
batch number: 849, batch size: 16, all together, number of images in train set: 13584
Below I present 6 good classified photos, with their explanation. Each of them was explained with the same three methods. I might have used function to do the calculations and printing, as they were in all cases similar, but I decided not to, as in this way, results of each part of code are shown directly below it, so this will help in coming back to this code while doing a project. Methods used are: LIME, IntegratedGradients and SHAP.
# Display test image 0 and compare its true class with the model's prediction.
print(type(inputs[1]))
grid_test = torchvision.utils.make_grid(inputs[0], nrow=10)
plt.figure(figsize=(10,10))
plt.imshow(grid_test.cpu().permute(1,2,0))
print(f'true label: {classes[labels[0]]}')
print(f'pred label: {classes[pred[0]]}')
print(labels[0])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<class 'torch.Tensor'> true label: rings pred label: rings tensor(66)
For me very dark picture, however still visible. I was curious, what made model think it is sport called 'rings'. I thought it were this round rings, but let's see...
#explanations
from captum.attr import Lime
from captum.attr import IntegratedGradients
# `segmentation` is otherwise first imported only further down in the file;
# import it here too so this cell runs in top-to-bottom order.
from skimage import segmentation

explainer_0 = Lime(loaded_model)
# Superpixel segmentation of image 0: LIME perturbs whole segments, not pixels,
# so this mask defines the interpretable features.
mask_0 = segmentation.quickshift(
    inputs[0].permute(1, 2, 0).double(),
    kernel_size=25,
    max_dist=7,
    ratio=0.7
)
mask_0
array([[ 47, 47, 47, ..., 72, 72, 72],
[ 47, 47, 47, ..., 72, 72, 72],
[ 47, 47, 47, ..., 72, 72, 72],
...,
[ 961, 961, 961, ..., 1014, 1014, 1014],
[ 961, 961, 961, ..., 1014, 1014, 1014],
[ 961, 961, 961, ..., 1014, 1014, 1014]], dtype=int64)
# LIME attribution for image 0 toward class 66 ('rings'), perturbing the
# quickshift superpixels from mask_0.
attr_0 = explainer_0.attribute(
inputs[0].unsqueeze(0),
target=66,
n_samples=20,
feature_mask=torch.as_tensor(mask_0),
show_progress=True
)
Lime attribution: 100%|██████████| 20/20 [00:03<00:00, 5.41it/s]
np.max(attr_0.tolist())
0.007658614311367273
show_image_mask_explanation(inputs[0], mask_0, attr_0[0].mean(axis=0)*100000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
show_attr(attr_0[0])
According to the LIME explanation, it seems as this prediction was some kind of random guess, as the explanation shows something in the surrounding of part of a feet of a men, for sure not rings.
# Integrated Gradients explanation for image 0, same target class (66).
exp_ig_0 = IntegratedGradients(loaded_model)
attr_ig_0 = exp_ig_0.attribute(inputs[0].unsqueeze(0), target=66)
show_attr(attr_ig_0[0])
For IntegratedGradients still not so much information (but much more than for LIME). There are no 'special places', which, as a whole part, made model to decide. What is interesting, the rings can be seen on the picture, but as a neutral surface, the one that has neither positive nor negative influence on the prediction.
from captum.attr import KernelShap
# KernelSHAP explanation for image 0 (class 66) over the same superpixel mask.
# NOTE(review): `exp_ks_0_2` is reused — after `.attribute` it holds the
# attribution tensor, no longer the explainer object.
exp_ks_0_2 = KernelShap(loaded_model)
exp_ks_0_2 = exp_ks_0_2.attribute(
inputs[0].unsqueeze(0),
target=66,
n_samples=300,
feature_mask=torch.as_tensor(mask_0),
show_progress=True
)
show_attr(exp_ks_0_2[0])
Kernel Shap attribution: 100%|██████████| 300/300 [00:48<00:00, 6.24it/s]
Only SHAP seems to be explaining the images, as we would. It seems to say something about rings (green visible parts of both rings) and position of a body (horizontal, not vertical).
# Display test image 2 with its true and predicted class names.
grid_test = torchvision.utils.make_grid(inputs[2], nrow=10)
plt.figure(figsize=(10,10))
plt.imshow(grid_test.cpu().permute(1,2,0))
print(f'true label: {classes[labels[2]]}')
# Fixed message: this line prints the *predicted* class, not the true one.
print(f'pred label: {classes[pred[2]]}')
print(labels[2])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
true label: wheelchair racing true label: wheelchair racing tensor(98)
This picture seems to be characteristic, as it clearly shows people on wheelschairs and doing some sports, so I was thinking, that wheels will be the most important part for model to make the decision.
from skimage import segmentation
# LIME setup for image 2 ('wheelchair racing'): build the superpixel mask first.
explainer_1 = Lime(loaded_model)
mask_1 = segmentation.quickshift(
inputs[2].permute(1, 2, 0).double(),
kernel_size=20,
max_dist=7,
ratio=0.5
)
mask_1
array([[ 12, 12, 12, ..., 7, 7, 7],
[ 12, 12, 12, ..., 15, 15, 7],
[ 12, 12, 12, ..., 15, 15, 15],
...,
[1188, 1188, 1188, ..., 1195, 1195, 1195],
[1188, 1188, 1188, ..., 1195, 1195, 1195],
[1188, 1188, 1188, ..., 1195, 1195, 1195]], dtype=int64)
# LIME attribution for image 2 toward class 98 ('wheelchair racing').
attr_1 = explainer_1.attribute(
inputs[2].unsqueeze(0),
target=98,
n_samples=20,
feature_mask=torch.as_tensor(mask_1),
show_progress=True
)
Lime attribution: 100%|██████████| 20/20 [00:03<00:00, 5.86it/s]
np.max(attr_1.tolist())
0.022972138598561287
def show_image_mask_explanation(image, mask, explanation):
    """Plot the input image, its segmentation mask, and the explanation side by side.

    `image` is a CHW tensor (permuted to HWC for imshow); `mask` is the
    superpixel label array; `explanation` is a 2-D attribution map shown on a
    diverging red/blue scale clipped to [-1, 1].
    """
    titles = ("image", "segmentation mask", "explanation")
    fig, axes = plt.subplots(1, 3, figsize=[6 * 2, 6])
    axes[0].imshow(image.permute(1, 2, 0))
    axes[1].imshow(mask, cmap="flag")
    axes[2].imshow(explanation, vmin=-1, vmax=1, cmap="RdBu")
    for axis, title in zip(axes, titles):
        axis.set_title(title)
    plt.show()
show_image_mask_explanation(inputs[2], mask_1, attr_1[0].mean(axis=0)*1000000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
from captum.attr import visualization
def show_attr(attr_map):
    """Render a CHW attribution tensor as a signed heat map with a colorbar."""
    heat = attr_map.permute(1, 2, 0).numpy()  # Captum expects an HWC numpy array
    visualization.visualize_image_attr(
        heat,
        method='heat_map',
        sign='all',
        show_colorbar=True
    )
show_attr(attr_1[0])
As previously assumed, even LIME showed, that wheel, the one in the main part of the picture, the biggest, had the most positive influence, to classify this image as 'wheelchair racing'.
from captum.attr import IntegratedGradients
# Integrated Gradients for image 2, target class 98.
exp_ig_1 = IntegratedGradients(loaded_model)
attr_ig_1 = exp_ig_1.attribute(inputs[2].unsqueeze(0), target=98)
show_attr(attr_ig_1[0])
For the IntegratedGradients method, we can also see the shape of a wheel, shown in green as having positive influence; however, the picture is chaotic.
from captum.attr import KernelShap
# KernelSHAP for image 2 (class 98) over mask_1; variable reused for the result.
exp_ks_1_1 = KernelShap(loaded_model)
exp_ks_1_1 = exp_ks_1_1.attribute(
inputs[2].unsqueeze(0),
target=98,
n_samples=300,
feature_mask=torch.as_tensor(mask_1),
show_progress=True
)
show_attr(exp_ks_1_1[0])
Kernel Shap attribution: 100%|██████████| 300/300 [00:45<00:00, 6.52it/s]
In this case also in SHAP method we can clearly see even more than one wheel, all coloured in green, so having positive impact. I think it is one of the best results of explanations done for this model in this file.
People doing 'rollerblade racing'. I was surprised, that the model was right for this picture, as the rollers are hardly visible. I hoped, that the explanation will show the rollers...
# Display test image 1 with its true and predicted class names.
grid_test = torchvision.utils.make_grid(inputs[1], nrow=10)
plt.figure(figsize=(10,10))
plt.imshow(grid_test.cpu().permute(1,2,0))
print(f'true label: {classes[labels[1]]}')
# Fixed message: this line prints the *predicted* class, not the true one.
print(f'pred label: {classes[pred[1]]}')
print(labels[1])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
true label: rollerblade racing true label: rollerblade racing tensor(69)
# LIME for image 1 ('rollerblade racing', class 69): segment, then attribute.
explainer_3= Lime(loaded_model)
mask_3 = segmentation.quickshift(
inputs[1].permute(1, 2, 0).double(),
kernel_size=25,
max_dist=7,
ratio=0.7
)
attr_3 = explainer_3.attribute(
inputs[1].unsqueeze(0),
target=69,
n_samples=20,
feature_mask=torch.as_tensor(mask_3),
show_progress=True
)
Lime attribution: 100%|██████████| 20/20 [00:03<00:00, 5.43it/s]
np.max(attr_3.tolist())
0.0
show_image_mask_explanation(inputs[1], mask_3, attr_3[0].mean(axis=0)*1000000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
For LIME method, it did not even work, as if model was not even trying to classify this picture to a class, based on some kind of observation on an image, but just guessed or so.
# Integrated Gradients for image 1, target class 69.
exp_ig_3 = IntegratedGradients(loaded_model)
attr_ig_3 = exp_ig_3.attribute(inputs[1].unsqueeze(0), target=69)
show_attr(attr_ig_3[0])
For this method, we can hardly see anything, when we know, what is on that picture, we could think of straight lines, being legs of mens, but that is it.
# KernelSHAP for image 1 (class 69); 500 samples for a finer estimate.
exp_ks_3_1 = KernelShap(loaded_model)
exp_ks_3_1 = exp_ks_3_1.attribute(
inputs[1].unsqueeze(0),
target=69,
n_samples=500,
feature_mask=torch.as_tensor(mask_3),
show_progress=True
)
show_attr(exp_ks_3_1[0])
Kernel Shap attribution: 100%|██████████| 500/500 [01:21<00:00, 6.13it/s]
Even explanation of a SHAP for this method, does not seems to make much of a sense. We can see some parts of people, for sure nothing, which seems to be rollers. For me it was a bit disappointing example.
This one seems to on one hand, have a potential of being good explained, as the javelin is clearly visible, however, on the other hand - javelin is small on the picture.
# Display test image 4 with its true and predicted class names.
grid_test = torchvision.utils.make_grid(inputs[4], nrow=10)
plt.figure(figsize=(10,10))
plt.imshow(grid_test.cpu().permute(1,2,0))
print(f'true label: {classes[labels[4]]}')
# Fixed message: this line prints the *predicted* class, not the true one.
print(f'pred label: {classes[pred[4]]}')
print(labels[4])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
true label: javelin true label: javelin tensor(50)
from skimage import segmentation
# LIME for image 4 ('javelin', class 50): segment, attribute, then inspect the
# maximum attribution value (0.0 here means LIME assigned no positive weight).
explainer_4= Lime(loaded_model)
mask_4 = segmentation.quickshift(
inputs[4].permute(1, 2, 0).double(),
kernel_size=25,
max_dist=7,
ratio=0.7
)
attr_4 = explainer_4.attribute(
inputs[4].unsqueeze(0),
target=50,
n_samples=20,
feature_mask=torch.as_tensor(mask_4),
show_progress=True
)
np.max(attr_4.tolist())
Lime attribution: 100%|██████████| 20/20 [00:05<00:00, 3.79it/s]
0.0
show_image_mask_explanation(inputs[4], mask_4, attr_4[0].mean(axis=0)*1000000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Also in this case, LIME had problems with saying at least anything...
# Integrated Gradients for image 4, target class 50.
exp_ig_4 = IntegratedGradients(loaded_model)
attr_ig_4 = exp_ig_4.attribute(inputs[4].unsqueeze(0), target=50)
show_attr(attr_ig_4[0])
Still messy IntegratedGradients; however, the javelin is clearly seen as a neutral surface.
from captum.attr import KernelShap
# KernelSHAP for image 4 (class 50) over mask_4.
exp_ks_4_1 = KernelShap(loaded_model)
exp_ks_4_1 = exp_ks_4_1.attribute(
inputs[4].unsqueeze(0),
target=50,
n_samples=500,
feature_mask=torch.as_tensor(mask_4),
show_progress=True
)
show_attr(exp_ks_4_1[0])
Kernel Shap attribution: 100%|██████████| 500/500 [01:23<00:00, 5.99it/s]
Finally, in the SHAP method, some interesting results can be seen. From the colors of this explanation we can easily assume that this picture shows a human with a javelin. Additionally, the colors are in such positions that we can say the explanation makes logical sense.
Three hockey players and someoune, probably referee, between them. Referee makes the picture less readable, as he is wearing dark clothes. Hopefully, even though the model can explain itself, why the prediction is good.
# Display test image 5 with its true and predicted class names.
grid_test = torchvision.utils.make_grid(inputs[5], nrow=10)
plt.figure(figsize=(10,10))
plt.imshow(grid_test.cpu().permute(1,2,0))
print(f'true label: {classes[labels[5]]}')
# Fixed message: this line prints the *predicted* class, not the true one.
print(f'pred label: {classes[pred[5]]}')
print(labels[5])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
true label: hockey true label: hockey tensor(41)
# LIME for image 5 ('hockey', class 41); note the coarser segmentation
# parameters (larger max_dist, ratio=1.0) than for the earlier images.
explainer_5= Lime(loaded_model)
mask_5 = segmentation.quickshift(
inputs[5].permute(1, 2, 0).double(),
kernel_size=20,
max_dist=25,
ratio=1.0
)
attr_5 = explainer_5.attribute(
inputs[5].unsqueeze(0),
target=41,
n_samples=20,
feature_mask=torch.as_tensor(mask_5),
show_progress=True
)
np.max(attr_5.tolist())
Lime attribution: 100%|██████████| 20/20 [00:02<00:00, 7.01it/s]
0.0
show_image_mask_explanation(inputs[5], mask_5, attr_5[0].mean(axis=0)*1000000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Surprisingly, we can see an interesting pattern on LIME explanation, as there seems to be one of the players visible.
from captum.attr import IntegratedGradients
# Integrated Gradients for image 5, target class 41.
exp_ig_5 = IntegratedGradients(loaded_model)
attr_ig_5 = exp_ig_5.attribute(inputs[5].unsqueeze(0), target=41)
show_attr(attr_ig_5[0])
In this method, the result is still messy; we can only say that the bottom-left corner had more impact on the decision of the model — of course, according to the IntegratedGradients explanation.
# KernelSHAP for image 5 (class 41) over mask_5.
exp_ks_5_1 = KernelShap(loaded_model)
exp_ks_5_1 = exp_ks_5_1.attribute(
inputs[5].unsqueeze(0),
target=41,
n_samples=500,
feature_mask=torch.as_tensor(mask_5),
show_progress=True
)
show_attr(exp_ks_5_1[0])
Kernel Shap attribution: 100%|██████████| 500/500 [01:15<00:00, 6.61it/s]
Finally SHAP, logically the explanation makes sense, there are two players in green and even one of the hockey sticks. The hockeysticks is the most green thing there, so had biggest good impact, that makes sense.
Being honest, I did not know about this sport before. The picture seems to be readable, especially the 'pipe' and helmet. Let's see what the explainers said.
# Display test image 6 with its true and predicted class names.
grid_test = torchvision.utils.make_grid(inputs[6], nrow=10)
plt.figure(figsize=(10,10))
plt.imshow(grid_test.cpu().permute(1,2,0))
print(f'true label: {classes[labels[6]]}')
# Fixed message: this line prints the *predicted* class, not the true one.
print(f'pred label: {classes[pred[6]]}')
print(labels[6])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
true label: pole climbing true label: pole climbing tensor(61)
# LIME for image 6 ('pole climbing', class 61).
explainer_6= Lime(loaded_model)
mask_6 = segmentation.quickshift(
inputs[6].permute(1, 2, 0).double(),
kernel_size=20,
max_dist=25,
ratio=1.0
)
attr_6 = explainer_6.attribute(
inputs[6].unsqueeze(0),
target=61,
n_samples=20,
feature_mask=torch.as_tensor(mask_6),
show_progress=True
)
np.max(attr_6.tolist())
Lime attribution: 100%|██████████| 20/20 [00:03<00:00, 5.63it/s]
0.0
show_image_mask_explanation(inputs[6], mask_6, attr_6[0].mean(axis=0)*1000000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
LIME decided to highlight the whole person, which in some way makes sense, because the person seems to be in a not usual position.
from captum.attr import IntegratedGradients
# Integrated Gradients for image 6, target class 61.
exp_ig_6 = IntegratedGradients(loaded_model)
attr_ig_6 = exp_ig_6.attribute(inputs[6].unsqueeze(0), target=61)
show_attr(attr_ig_6[0])
This time, we can see some shapes, the person and part of a pipe, however, I can say that only because I know the picture. But this time we can see, which regions of a picture were more influencing the decidion than others.
# KernelSHAP for image 6 (class 61) over mask_6.
exp_ks_6_1 = KernelShap(loaded_model)
exp_ks_6_1 = exp_ks_6_1.attribute(
inputs[6].unsqueeze(0),
target=61,
n_samples=500,
feature_mask=torch.as_tensor(mask_6),
show_progress=True
)
show_attr(exp_ks_6_1[0])
Kernel Shap attribution: 100%|██████████| 500/500 [01:24<00:00, 5.91it/s]
And SHAP explanation, without big surprises, performed best here, showing what a human would say about most important parts of a picture.
Now I would like to look closer at a picture with a snowboard, which for a human seems to be easy to recognize. On the contrary, the model had problems predicting the correct label.
# Display the misclassified image 8: true 'snow boarding' (79), predicted
# 'ice climbing' (47).
grid_test = torchvision.utils.make_grid(inputs[8], nrow=10)
plt.figure(figsize=(10,10))
plt.imshow(grid_test.cpu().permute(1,2,0))
print(f'true label: {classes[labels[8]]}')
print(f'pred label: {classes[pred[8]]}')
print(labels[8])
print(pred[8])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
true label: snow boarding pred label: ice climbing tensor(79) tensor(47)
# LIME setup for the misclassified image 8: finer superpixels (kernel_size=15).
explainer_2 = Lime(loaded_model)
mask_2 = segmentation.quickshift(
inputs[8].permute(1, 2, 0).double(),
kernel_size=15, #15
max_dist=10,
ratio=0.7
)
mask_2
array([[12, 12, 12, ..., 24, 24, 24],
[12, 12, 12, ..., 24, 24, 24],
[12, 12, 12, ..., 24, 24, 24],
...,
[68, 68, 68, ..., 86, 86, 86],
[68, 68, 68, ..., 86, 86, 86],
[68, 68, 84, ..., 86, 86, 86]], dtype=int64)
# LIME attribution for image 8 toward the *predicted* (wrong) class 47.
attr_2 = explainer_2.attribute(
inputs[8].unsqueeze(0),
target=47,
n_samples=20,
feature_mask=torch.as_tensor(mask_2),
show_progress=True
)
Lime attribution: 100%|██████████| 20/20 [00:02<00:00, 6.73it/s]
np.max(attr_2.tolist())
0.0
show_image_mask_explanation(inputs[8], mask_2, attr_2[0].mean(axis=0)*1000000)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Also this time LIME had problems giving a proof for the decision of a model. Even though the segmentation looks really cool.
show_attr(attr_2[0])
# A second LIME run with identical arguments (class 47, same mask).
# NOTE(review): `attr_2_2` is never visualized in the visible code — this looks
# like a leftover re-run; confirm it is intentional.
explainer_2_2 = Lime(loaded_model)
attr_2_2 = explainer_2_2.attribute(
inputs[8].unsqueeze(0),
target=47,
n_samples=20,
feature_mask=torch.as_tensor(mask_2),
show_progress=True
)
Lime attribution: 100%|██████████| 20/20 [00:03<00:00, 5.98it/s]
# Integrated Gradients for image 8, first toward the predicted class (47)...
exp_ig_2 = IntegratedGradients(loaded_model)
attr_ig_2 = exp_ig_2.attribute(inputs[8].unsqueeze(0), target=47)
show_attr(attr_ig_2[0])
# ...then toward the true class (79). The same variable names are reused, so
# only the second (target=79) attribution survives after this cell.
exp_ig_2 = IntegratedGradients(loaded_model)
attr_ig_2 = exp_ig_2.attribute(inputs[8].unsqueeze(0), target=79)
show_attr(attr_ig_2[0])
Also this time IntegratedGradien gave some messy output, only slightly showing, which regions were having more impact on decision. These regions seems to be similar for both true label and predicted (wrong) label.
#for predicted, wrong label
# KernelSHAP for image 8 toward the predicted class 47 ('ice climbing').
from captum.attr import KernelShap
exp_ks_2_1 = KernelShap(loaded_model)
exp_ks_2_1 = exp_ks_2_1.attribute(
inputs[8].unsqueeze(0),
target=47,
n_samples=300,
feature_mask=torch.as_tensor(mask_2),
show_progress=True
)
show_attr(exp_ks_2_1[0])
Kernel Shap attribution: 100%|██████████| 300/300 [00:44<00:00, 6.75it/s]
#for correct label, not predicted
# KernelSHAP for image 8 toward the true class 79 ('snow boarding').
# NOTE(review): this overwrites `exp_ks_2_1` from the previous cell.
from captum.attr import KernelShap
exp_ks_2_1 = KernelShap(loaded_model)
exp_ks_2_1 = exp_ks_2_1.attribute(
inputs[8].unsqueeze(0),
target=79,
n_samples=300,
feature_mask=torch.as_tensor(mask_2),
show_progress=True
)
show_attr(exp_ks_2_1[0])
Kernel Shap attribution: 100%|██████████| 300/300 [00:58<00:00, 5.13it/s]
In this case, the result surprised me the most. Model explains, that the most good part for true label is a board, what makes a lot of sense. However, the amout of parts of a picture, which makes model think, it is not snowboard is too big. In this case, model says, that the predicion is 'ice climbing' which is not correct. However, on the explanation for this idea, model gives quite good resons, that means, part of a hill with snow, in right bottom corner.
It was really interesting to see, which parts of a picture made model make decision about the prediction. A bit disappointing was, that not for all images, LIME method was able to give any ideas. Also, IntegratedGradients was nearly in all cases too messy, to think of any shapes, that might be on a picture, shapes which made model think of a solution. The best method, I like it the most, was SHAP. Nearly in all cases, the explanation were logically acceptable for me, as things that I would also say, when ased, why I think that there is snowboarding person or climbing one. The most surprising was the last one, wrong classified, even though board was correctly found. I hope, that this situation took only place bacause the model had only 75% accuracy on this test batch. I wish training of this model took less (nearly 11h on my computer to calculate only part of a epoch), so I would be able to wait longer and see the explanations of a model ahving over 90% of accuracy (in the source mentioned before, they made it so good).